import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def plot_fig(data=None):
    """Show a 2-D or 3-D scatter plot of the rows of *data*.

    Parameters
    ----------
    data : array_like, shape (n_points, n_cols), optional
        Point coordinates, one point per row. With exactly two columns a
        2-D scatter plot is drawn; with three or more columns a 3-D scatter
        plot of the first three columns is drawn (extra columns are
        ignored). ``None`` (the default) is treated as an empty ``(0, 0)``
        array and is therefore rejected.

    Raises
    ------
    ValueError
        If *data* is not two-dimensional or has fewer than two columns.

    Notes
    -----
    Ends with ``plt.show()``, which, depending on the active backend, may
    block until the plot window is closed.
    """
    # Build a fresh empty array per call rather than sharing one mutable
    # default object across every call.
    points = np.zeros((0, 0)) if data is None else np.asarray(data)

    # Validate before creating a figure so bad input fails with a clear
    # message instead of an IndexError after an empty window was opened.
    if points.ndim != 2 or points.shape[1] < 2:
        raise ValueError(
            "plot_fig expects a 2-D array with at least 2 columns, "
            f"got shape {points.shape}"
        )

    fig = plt.figure()
    if points.shape[1] == 2:
        ax = fig.add_subplot(111)
        ax.scatter(points[:, 0], points[:, 1])
    else:
        # Axes3D(fig) no longer registers itself with the figure (deprecated
        # in matplotlib 3.4, default since 3.6), which left the 3-D plot
        # blank; have the figure create and own the 3-D axes instead.
        ax = fig.add_subplot(111, projection="3d")
        ax.scatter3D(points[:, 0], points[:, 1], points[:, 2])
    plt.show()